Add FSQ implementation #74
Conversation
nice! give it a test drive with the cifar script in the examples folder @kashif ^
Still not 100% sure this is correct, but initial results seem promising.

vq  :: rec loss: 0.114 | cmt loss: 0.001 | active %: 21.094
fsq :: rec loss: 0.111 | active %: 58.333

Not parameter matched, so losses aren't really representative, but interesting to see the higher codebook usage.
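(For context, "active %" here is just the fraction of codebook entries hit at least once in a batch. A quick sketch of one way to compute it, with hypothetical names, not code from this PR:)

```python
import torch

# Illustrative only: the fraction of codebook entries that appear at least
# once in a batch of code indices. `indices` / `codebook_size` are made-up names.
def active_percent(indices: torch.Tensor, codebook_size: int) -> float:
    return indices.unique().numel() / codebook_size * 100

indices = torch.randint(0, 512, (4096,))  # fake batch of code indices
print(f"active %: {active_percent(indices, 512):.3f}")
```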
do you have the training curves for each?
edit: removed the images that were here because the test was broken
Probably isn't representative of FSQ performance in general, but it does seem to be working on some level at least ^^
Ok wait, strike that. Reading the paper a bit more closely, the offsets should be flipped relative to what I put there initially. Doing so nets a much higher active %.
```python
return z + (zhat - z).detach()

class FSQ(nn.Module):
```
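For context on the two lines shown in the diff above: `z + (zhat - z).detach()` is the usual straight-through estimator trick, which lets gradients flow through the non-differentiable rounding. A minimal sketch of how that slots into an FSQ-style quantize step — the function names and the odd-levels-only handling are my own simplification, not this PR's code:

```python
import torch

def round_ste(z: torch.Tensor) -> torch.Tensor:
    """Round to the nearest integer, but let gradients pass straight through."""
    zhat = z.round()
    return z + (zhat - z).detach()  # forward: zhat, backward: identity w.r.t. z

def quantize(z: torch.Tensor, levels: int = 5) -> torch.Tensor:
    """Sketch of an FSQ-style quantize step for odd `levels` (no even-level offset handling)."""
    half_width = levels // 2
    # bound each dimension to (-half_width, half_width), round, then renormalize
    return round_ste(z.tanh() * half_width) / half_width
```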
Thanks so much for porting this! Do you mind if we link this repo in the next version and our own public code release?
LMK if you are also planning to update the README and I can send some figs.
please do! and i believe @lucidrains and @sekstini will appreciate it
Yup, go ahead 👍
> LMK if you are also planning to update the README and I can send some figs.

That would be great 🙏
Yeah, the test models are both shallow, and tiny at ~10k parameters. Don't expect this to realistically reflect performance at all. This is the FSQ network.
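Purely as an illustration of the kind of shallow, tiny model being described — layer sizes and names here are invented, not the actual test network from this thread:

```python
import torch
from torch import nn

# Hypothetical toy model: shallow encoder -> FSQ bottleneck -> shallow decoder.
# All layer sizes are made up; the quantizer is assumed to accept (..., dim)
# tensors and return (quantized, indices), but check the repo's FSQ API.
class TinyFSQAutoEncoder(nn.Module):
    def __init__(self, quantizer: nn.Module, dim: int = 4):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(16, dim, 3, stride=2, padding=1),
        )
        self.quantizer = quantizer
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(dim, 16, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 4, stride=2, padding=1),
        )

    def forward(self, x):
        z = self.encoder(x)                  # (b, dim, h, w)
        z = z.permute(0, 2, 3, 1)            # quantize over the channel dimension
        zq, indices = self.quantizer(z)
        zq = zq.permute(0, 3, 1, 2)
        return self.decoder(zq), indices
```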
Will do 👍
Makes sense! Yeah one thing with FSQ is that because your codebook is fixed, you need some decent number of layers before and after the quantizer. Usually this is the case (eg the VQ-GANs people train are fairly deep). Anyway, I think the results you see are what we expect! Good stuff. For the figures, I revisited your README and it seems it is usually mostly the code, so maybe simply using the cube (PDF here) would be enough or already too much. Your call :) In the README of our upcoming mini-repo, I also added the table, but again, this might be out of place in your README:
[table and source link not reproduced here]
lgtm! releasing, thank you @sekstini!
@lucidrains Oh, apparently I pointed this at the fsq branch, so you might need to merge it into master
@sekstini yup, no problem, thank you! 🙏
Thanks everyone! I'll add a link to this repo in our next revision. Added a link to the official README now.
Appreciate this quick port and example code. Just thought I would add that reconstructions from the FashionMNIST example appear reasonable. However, a bit of a surprise noticing it fails if switching back to MNIST. Changing levels/seeds didn't help, so perhaps, as @fab-jul mentioned, it's just a case of needing more layers before/after the quantizer for some cases.
@dribnet Interesting. Not sure why we would see this particular failure mode here, but I made a toy example of residual vector quantization where it's working: https://gist.github.com/sekstini/7f089f71d4b975ec8bde37d878b514d0.
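For anyone skimming: residual quantization applies the quantizer repeatedly to whatever is left over after each stage. A rough sketch of that pattern, with made-up names (`quantize_fn` standing in for any FSQ/VQ module), not the code from the gist:

```python
import torch

def residual_quantize(z: torch.Tensor, quantize_fn, n_quantizers: int = 4):
    """Quantize `z` in stages; each stage quantizes the residual left by the previous ones."""
    residual = z
    out = torch.zeros_like(z)
    codes = []
    for _ in range(n_quantizers):
        q, idx = quantize_fn(residual)  # assumes quantize_fn returns (quantized, indices)
        out = out + q                   # running sum approximates z better each stage
        residual = residual - q
        codes.append(idx)
    return out, codes
```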
@dribnet adding more layers seems testable
LFQ https://arxiv.org/abs/2310.05737 looks similar?
nevermind, it is slightly different, will be adding
It's FSQ with levels = 2 plus an entropy maximization loss, IIUC.
yes, FSQ generalizes it, save for the entropy loss. no ablation of the entropy loss, so unsure how necessary it is, but i'll add it. perhaps that will be their contribution
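As a rough sketch of the LFQ idea itself (each latent dimension quantized to ±1 by its sign, with the bit pattern giving the code index), leaving out the entropy loss discussed above — names are illustrative, not this repo's API:

```python
import torch

def lfq(z: torch.Tensor):
    """Sketch of lookup-free quantization: each latent dim -> {-1, +1} by its sign."""
    q = (z > 0).float() * 2 - 1              # per-dimension binary code in {-1, +1}
    q = z + (q - z).detach()                 # straight-through estimator for gradients
    bits = (z > 0).long()
    powers = 2 ** torch.arange(z.shape[-1], device=z.device)
    indices = (bits * powers).sum(dim=-1)    # integer codebook index from the bit pattern
    return q, indices
```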
@fab-jul congrats either way! this could be a winner
@lucidrains thanks! for us, levels = 2 was suboptimal but I can see how an entropy maximization would fix that. I wonder if that also improves FSQ (although ofc the goal of our paper was to get rid of all aux losses haha)
Funny you mention it, as I'm experimenting with a vocoder based on this idea at this very moment ^^
@sekstini are you seeing good results? thinking about a soundstream variation with multi-headed LFQ
Not really, but definitely not because of FSQ. I have exclusively been toying with various parallel encoding schemes, which I'm guessing are difficult for the model to learn, and I suspect residual quantization would work a lot better.
@sekstini ah got it, thanks! yea, to do residual, i think the codes will need to be scaled down an order of magnitude, or inputs scaled up (cube within cubes, if you can imagine it), but i haven't worked it out. it is probably a 3 month research project. somebody will def end up trying it..
ok it is done, integrated in soundstream over here
may be of interest! lucidrains/magvit2-pytorch#4
@mueller-franzes You may find this comment by Fabian interesting.
Makes sense to me. I copied it directly from the paper, but at the time there was a
As a side note, you may want to check out LFQ if you're interested in the levels = 2 case in particular.
@mueller-franzes @sekstini maybe we should just enforce odd levels for now for FSQ until a correction to the paper comes out?
and yea, agreed with checking out LFQ. so many groups seeing success with it
Other than the asymmetry being weird, levels > 2 seems fine (actually most of my code has been using even values). I think switching to (1 + eps) or enforcing levels > 2 makes sense.
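Not from this PR — just a quick numerical illustration of the asymmetry being discussed, using my paraphrase of the paper's `bound` pseudocode (details may differ from the repo's final version):

```python
import torch

def bound(z: torch.Tensor, levels: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    # paraphrase of the paper's pseudocode -- not necessarily the repo's final code
    half_l = (levels - 1) * (1 - eps) / 2
    offset = (levels % 2 == 0).float() * 0.5   # even levels get a half-step offset
    shift = torch.tan(offset / half_l)
    return torch.tanh(z + shift) * half_l - offset

levels = torch.tensor([2.0])
z = torch.linspace(-3, 3, 7).unsqueeze(-1)
print(bound(z, levels).round().squeeze(-1))
# with levels == 2 the boundary between the two codes sits well away from z = 0,
# so code usage ends up lopsided; odd levels don't have this issue (offset == 0)
```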
Thank you both for the super quick response! That's a good tip, I'll have a look at LFQ next.
@sekstini oh you are right, it is a levels == 2 problem. ok let's just go with
TODO:

Notes:
`bound` function) is missing. Just took copilot's suggestion for now.